import json
import numpy as np
import scipy.io as sio
import os


def Format_Pred(pred_file):
    """Group the HOI predictions stored in a JSON file by HOI class.

    The file must contain a list of per-image records.  Each record is read
    through these keys:

    * ``'file_name'`` - image file name; the image id is the integer after
      the last ``'_'`` of the part of the name that precedes the first ``'.'``.
    * ``'predictions'`` - list of detected boxes, each with ``'bbox'`` and
      ``'category_id'``.
    * ``'hoi_prediction'`` - list of interactions, each with ``'subject_id'``
      and ``'object_id'`` (indices into ``'predictions'``), ``'category_id'``
      (the verb) and ``'score'``.

    Args:
        pred_file: path of the JSON prediction file.

    Returns:
        A dict mapping the HOI class (an ``int`` taken from the lookup table
        below) to the list of its detections.  Every detection is a dict with
        the keys ``'img_id'``, ``'human_box'``, ``'object_box'`` and
        ``'score'``.  Interactions whose (object category, verb category)
        pair is not in the table are printed and skipped.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(pred_file, 'r') as pred_fh:
        orig_file = json.load(pred_fh)
    # (object category id, verb category id) -> HOI class.
    hoi_label = {(5, 5): 1, (5, 18): 2, (5, 26): 3, (5, 31): 4, (5, 42): 5, (5, 53): 6, (5, 77): 7, (5, 88): 8,
                 (5, 112): 9, (5, 58): 10, (2, 9): 11, (2, 37): 12, (2, 42): 13, (2, 44): 14, (2, 38): 15, (2, 63): 16,
                 (2, 72): 17, (2, 76): 18, (2, 77): 19, (2, 88): 20, (2, 99): 21, (2, 111): 22, (2, 112): 23,
                 (2, 58): 24, (16, 11): 25, (16, 27): 26, (16, 37): 27, (16, 66): 28, (16, 75): 29, (16, 113): 30,
                 (16, 58): 31, (9, 5): 32, (9, 22): 33, (9, 26): 34, (9, 42): 35, (9, 44): 36, (9, 48): 37, (9, 76): 38,
                 (9, 77): 39, (9, 78): 40, (9, 80): 41, (9, 88): 42, (9, 94): 43, (9, 106): 44, (9, 112): 45,
                 (9, 58): 46, (44, 9): 47, (44, 21): 48, (44, 37): 49, (44, 42): 50, (44, 49): 51, (44, 59): 52,
                 (44, 70): 53, (44, 58): 54, (6, 5): 55, (6, 18): 56, (6, 22): 57, (6, 26): 58, (6, 42): 59,
                 (6, 53): 60, (6, 77): 61, (6, 88): 62, (6, 112): 63, (6, 114): 64, (6, 58): 65, (3, 5): 66,
                 (3, 18): 67, (3, 22): 68, (3, 39): 69, (3, 42): 70, (3, 44): 71, (3, 53): 72, (3, 63): 73, (3, 77): 74,
                 (3, 112): 75, (3, 58): 76, (17, 23): 77, (17, 27): 78, (17, 37): 79, (17, 40): 80, (17, 46): 81,
                 (17, 66): 82, (17, 81): 83, (17, 112): 84, (17, 11): 85, (17, 58): 86, (62, 9): 87, (62, 37): 88,
                 (62, 50): 89, (62, 88): 90, (62, 94): 91, (62, 58): 92, (63, 9): 93, (63, 50): 94, (63, 88): 95,
                 (63, 58): 96, (21, 27): 97, (21, 35): 98, (21, 37): 99, (21, 40): 100, (21, 46): 101, (21, 47): 102,
                 (21, 56): 103, (21, 66): 104, (21, 77): 105, (21, 111): 106, (21, 58): 107, (67, 13): 108,
                 (67, 25): 109, (67, 87): 110, (67, 58): 111, (18, 9): 112, (18, 23): 113, (18, 27): 114, (18, 34): 115,
                 (18, 37): 116, (18, 39): 117, (18, 40): 118, (18, 42): 119, (18, 46): 120, (18, 66): 121,
                 (18, 79): 122, (18, 81): 123, (18, 99): 124, (18, 108): 125, (18, 111): 126, (18, 112): 127,
                 (18, 11): 128, (18, 58): 129, (19, 27): 130, (19, 34): 131, (19, 37): 132, (19, 40): 133,
                 (19, 44): 134, (19, 46): 135, (19, 53): 136, (19, 38): 137, (19, 66): 138, (19, 73): 139,
                 (19, 77): 140, (19, 79): 141, (19, 99): 142, (19, 108): 143, (19, 111): 144, (19, 112): 145,
                 (19, 58): 146, (4, 37): 147, (4, 42): 148, (4, 44): 149, (4, 38): 150, (4, 63): 151, (4, 72): 152,
                 (4, 73): 153, (4, 77): 154, (4, 88): 155, (4, 99): 156, (4, 109): 157, (4, 111): 158, (4, 112): 159,
                 (4, 58): 160, (1, 9): 161, (1, 32): 162, (1, 37): 163, (1, 40): 164, (1, 46): 165, (1, 93): 166,
                 (1, 101): 167, (1, 103): 168, (1, 49): 169, (1, 58): 170, (64, 9): 171, (64, 37): 172, (64, 39): 173,
                 (64, 58): 174, (20, 9): 175, (20, 27): 176, (20, 35): 177, (20, 37): 178, (20, 40): 179, (20, 46): 180,
                 (20, 66): 181, (20, 77): 182, (20, 84): 183, (20, 111): 184, (20, 112): 185, (20, 58): 186,
                 (7, 5): 187, (7, 22): 188, (7, 26): 189, (7, 53): 190, (7, 77): 191, (7, 88): 192, (7, 112): 193,
                 (7, 58): 194, (72, 14): 195, (72, 76): 196, (72, 113): 197, (72, 58): 198, (53, 8): 199, (53, 16): 200,
                 (53, 24): 201, (53, 37): 202, (53, 42): 203, (53, 65): 204, (53, 67): 205, (53, 90): 206,
                 (53, 112): 207, (53, 58): 208, (27, 9): 209, (27, 37): 210, (27, 42): 211, (27, 59): 212,
                 (27, 115): 213, (27, 58): 214, (52, 8): 215, (52, 9): 216, (52, 16): 217, (52, 24): 218, (52, 37): 219,
                 (52, 42): 220, (52, 65): 221, (52, 67): 222, (52, 90): 223, (52, 58): 224, (39, 6): 225, (39, 9): 226,
                 (39, 37): 227, (39, 85): 228, (39, 100): 229, (39, 105): 230, (39, 116): 231, (39, 58): 232,
                 (40, 37): 233, (40, 115): 234, (40, 58): 235, (23, 27): 236, (23, 41): 237, (23, 113): 238,
                 (23, 58): 239, (65, 13): 240, (65, 50): 241, (65, 88): 242, (65, 58): 243, (15, 42): 244,
                 (15, 50): 245, (15, 88): 246, (15, 58): 247, (84, 9): 248, (84, 37): 249, (84, 59): 250, (84, 74): 251,
                 (84, 58): 252, (51, 37): 253, (51, 97): 254, (51, 112): 255, (51, 49): 256, (51, 58): 257,
                 (56, 16): 258, (56, 24): 259, (56, 37): 260, (56, 90): 261, (56, 97): 262, (56, 112): 263,
                 (56, 58): 264, (61, 4): 265, (61, 9): 266, (61, 16): 267, (61, 24): 268, (61, 37): 269, (61, 52): 270,
                 (61, 55): 271, (61, 68): 272, (61, 58): 273, (57, 9): 274, (57, 15): 275, (57, 16): 276, (57, 24): 277,
                 (57, 37): 278, (57, 65): 279, (57, 90): 280, (57, 97): 281, (57, 112): 282, (57, 58): 283,
                 (77, 9): 284, (77, 37): 285, (77, 74): 286, (77, 76): 287, (77, 102): 288, (77, 104): 289,
                 (77, 58): 290, (85, 12): 291, (85, 37): 292, (85, 76): 293, (85, 83): 294, (85, 58): 295, (47, 9): 296,
                 (47, 21): 297, (47, 37): 298, (47, 42): 299, (47, 70): 300, (47, 86): 301, (47, 90): 302,
                 (47, 28): 303, (47, 112): 304, (47, 58): 305, (60, 8): 306, (60, 9): 307, (60, 24): 308, (60, 37): 309,
                 (60, 55): 310, (60, 68): 311, (60, 90): 312, (60, 58): 313, (22, 27): 314, (22, 37): 315,
                 (22, 39): 316, (22, 40): 317, (22, 46): 318, (22, 38): 319, (22, 66): 320, (22, 77): 321,
                 (22, 111): 322, (22, 112): 323, (22, 113): 324, (22, 58): 325, (11, 40): 326, (11, 42): 327,
                 (11, 59): 328, (11, 62): 329, (11, 58): 330, (48, 37): 331, (48, 51): 332, (48, 96): 333,
                 (48, 49): 334, (48, 112): 335, (48, 58): 336, (34, 3): 337, (34, 10): 338, (34, 37): 339,
                 (34, 91): 340, (34, 105): 341, (34, 58): 342, (25, 27): 343, (25, 46): 344, (25, 66): 345,
                 (25, 77): 346, (25, 113): 347, (25, 58): 348, (89, 37): 349, (89, 60): 350, (89, 76): 351,
                 (89, 58): 352, (31, 9): 353, (31, 37): 354, (31, 42): 355, (31, 58): 356, (58, 9): 357, (58, 15): 358,
                 (58, 16): 359, (58, 24): 360, (58, 37): 361, (58, 55): 362, (58, 58): 363, (76, 9): 364, (76, 13): 365,
                 (76, 37): 366, (76, 110): 367, (76, 58): 368, (38, 2): 369, (38, 9): 370, (38, 31): 371, (38, 37): 372,
                 (38, 42): 373, (38, 48): 374, (38, 71): 375, (38, 58): 376, (49, 17): 377, (49, 37): 378,
                 (49, 96): 379, (49, 112): 380, (49, 116): 381, (49, 49): 382, (49, 58): 383, (73, 37): 384,
                 (73, 59): 385, (73, 74): 386, (73, 76): 387, (73, 110): 388, (73, 58): 389, (78, 13): 390,
                 (78, 59): 391, (78, 60): 392, (78, 58): 393, (74, 14): 394, (74, 37): 395, (74, 76): 396,
                 (74, 58): 397, (55, 8): 398, (55, 16): 399, (55, 24): 400, (55, 37): 401, (55, 42): 402, (55, 65): 403,
                 (55, 67): 404, (55, 92): 405, (55, 112): 406, (55, 58): 407, (79, 13): 408, (79, 37): 409,
                 (79, 42): 410, (79, 59): 411, (79, 76): 412, (79, 60): 413, (79, 58): 414, (14, 12): 415,
                 (14, 64): 416, (14, 76): 417, (14, 58): 418, (59, 8): 419, (59, 9): 420, (59, 15): 421, (59, 16): 422,
                 (59, 24): 423, (59, 37): 424, (59, 55): 425, (59, 68): 426, (59, 89): 427, (59, 90): 428,
                 (59, 58): 429, (82, 13): 430, (82, 37): 431, (82, 57): 432, (82, 59): 433, (82, 58): 434,
                 (75, 37): 435, (75, 69): 436, (75, 100): 437, (75, 58): 438, (54, 9): 439, (54, 15): 440,
                 (54, 16): 441, (54, 24): 442, (54, 37): 443, (54, 55): 444, (54, 58): 445, (87, 17): 446,
                 (87, 37): 447, (87, 59): 448, (87, 58): 449, (81, 13): 450, (81, 76): 451, (81, 112): 452,
                 (81, 58): 453, (41, 9): 454, (41, 29): 455, (41, 33): 456, (41, 37): 457, (41, 44): 458, (41, 68): 459,
                 (41, 77): 460, (41, 88): 461, (41, 94): 462, (41, 58): 463, (35, 1): 464, (35, 9): 465, (35, 37): 466,
                 (35, 42): 467, (35, 44): 468, (35, 68): 469, (35, 76): 470, (35, 77): 471, (35, 94): 472,
                 (35, 115): 473, (35, 58): 474, (36, 1): 475, (36, 9): 476, (36, 33): 477, (36, 37): 478, (36, 44): 479,
                 (36, 77): 480, (36, 94): 481, (36, 115): 482, (36, 58): 483, (50, 37): 484, (50, 49): 485,
                 (50, 112): 486, (50, 86): 487, (50, 58): 488, (37, 3): 489, (37, 9): 490, (37, 10): 491, (37, 20): 492,
                 (37, 36): 493, (37, 37): 494, (37, 42): 495, (37, 45): 496, (37, 68): 497, (37, 82): 498,
                 (37, 85): 499, (37, 91): 500, (37, 105): 501, (37, 58): 502, (13, 37): 503, (13, 95): 504,
                 (13, 98): 505, (13, 58): 506, (33, 9): 507, (33, 19): 508, (33, 37): 509, (33, 40): 510, (33, 53): 511,
                 (33, 59): 512, (33, 61): 513, (33, 68): 514, (33, 117): 515, (33, 58): 516, (42, 9): 517,
                 (42, 19): 518, (42, 37): 519, (42, 42): 520, (42, 44): 521, (42, 50): 522, (42, 53): 523,
                 (42, 77): 524, (42, 94): 525, (42, 88): 526, (42, 112): 527, (42, 58): 528, (88, 9): 529,
                 (88, 37): 530, (88, 40): 531, (88, 46): 532, (88, 58): 533, (43, 9): 534, (43, 37): 535, (43, 42): 536,
                 (43, 100): 537, (43, 58): 538, (32, 1): 539, (32, 16): 540, (32, 37): 541, (32, 42): 542,
                 (32, 71): 543, (32, 106): 544, (32, 115): 545, (32, 58): 546, (80, 37): 547, (80, 60): 548,
                 (80, 76): 549, (80, 58): 550, (70, 13): 551, (70, 30): 552, (70, 59): 553, (70, 76): 554,
                 (70, 88): 555, (70, 94): 556, (70, 112): 557, (70, 58): 558, (90, 7): 559, (90, 37): 560,
                 (90, 112): 561, (90, 58): 562, (10, 43): 563, (10, 76): 564, (10, 95): 565, (10, 98): 566,
                 (10, 58): 567, (8, 18): 568, (8, 22): 569, (8, 42): 570, (8, 53): 571, (8, 76): 572, (8, 77): 573,
                 (8, 88): 574, (8, 112): 575, (8, 58): 576, (28, 9): 577, (28, 37): 578, (28, 54): 579, (28, 59): 580,
                 (28, 76): 581, (28, 83): 582, (28, 95): 583, (28, 58): 584, (86, 37): 585, (86, 55): 586,
                 (86, 62): 587, (86, 58): 588, (46, 28): 589, (46, 37): 590, (46, 86): 591, (46, 107): 592,
                 (46, 49): 593, (46, 112): 594, (46, 58): 595, (24, 27): 596, (24, 37): 597, (24, 66): 598,
                 (24, 113): 599, (24, 58): 600}
    out_pred = {}
    for annot in orig_file:
        annot_bbox = annot['predictions']
        # Image id is the numeric suffix of the file name stem.
        img_id = int((annot['file_name'].split('.')[0]).split('_')[-1])
        for hoi in annot['hoi_prediction']:
            sub_bbox = annot_bbox[hoi['subject_id']]
            obj_bbox = annot_bbox[hoi['object_id']]
            human_box = sub_bbox['bbox']
            object_box = obj_bbox['bbox']
            score = hoi['score']
            triplet = (obj_bbox['category_id'], hoi['category_id'])
            if triplet not in hoi_label:
                # Not a known (object, verb) combination: report and drop it.
                print(triplet)
                continue
            this_out = {'img_id': img_id, 'human_box': human_box, 'object_box': object_box, 'score': score}
            out_pred.setdefault(hoi_label[triplet], []).append(this_out)
    return out_pred


def save_HICO(HICO, HICO_dir, classid, begin, finish):
    """Write the detections of one object category to a ``.mat`` file.

    Args:
        HICO: mapping from HOI class id to a list of detection dicts, each
            with the keys ``'human_box'``, ``'object_box'``, ``'img_id'`` and
            ``'score'``.  A class id may be stored either as an ``int`` or as
            its decimal string; the string form is tried first.
        HICO_dir: directory that receives the output file.
        classid: object category number; zero-padded to two digits it forms
            the file name ``detections_<classid>.mat``.
        begin: first HOI class id of this object category (inclusive).
        finish: last HOI class id of this object category (inclusive).

    The file stores one variable, ``all_boxes``.  Each row is
    ``[human_box, object_box, img_id, hoi_id - begin, score * 1000]``.  Rows
    are grouped by HOI class in ascending id order and, inside a class,
    ordered by descending score and capped at 19999 rows.  Class ids with no
    entry in ``HICO`` are printed.
    """
    all_boxes = []
    for hoi_id in range(begin, finish + 1):
        # Accept both string keys and int keys: looking up str(hoi_id) only
        # misses every class when the mapping is keyed by int.
        detections = HICO.get(str(hoi_id))
        if detections is None:
            detections = HICO.get(hoi_id)
        if detections is None:
            print(hoi_id)
            continue

        rows = []
        scores = []
        for element in detections:
            scaled_score = element['score'] * 1000
            rows.append([
                element['human_box'],  # Human box
                element['object_box'],  # Object box
                element['img_id'],  # image id
                int(hoi_id - begin),  # HOI class offset inside this object category
                scaled_score,
            ])
            scores.append(scaled_score)

        # Highest score first, keeping at most 19999 detections per class.
        order = np.argsort(scores, axis=0)[::-1]
        for idx in order[:19999]:
            all_boxes.append(rows[idx])

    # os.path.join works whether or not HICO_dir ends with a separator.
    savefile = os.path.join(HICO_dir, 'detections_' + str(classid).zfill(2) + '.mat')
    sio.savemat(savefile, {'all_boxes': all_boxes})


# Inclusive (begin, finish) HOI class id range of every object category.
# Entries are listed in category order: the k-th entry (counting from 1)
# is written to detections_<k>.mat.
_HICO_OBJECT_RANGES = (
    (161, 170),  # 1 person
    (11, 24),  # 2 bicycle
    (66, 76),  # 3 car
    (147, 160),  # 4 motorcycle
    (1, 10),  # 5 airplane
    (55, 65),  # 6 bus
    (187, 194),  # 7 train
    (568, 576),  # 8 truck
    (32, 46),  # 9 boat
    (563, 567),  # 10 traffic light
    (326, 330),  # 11 fire_hydrant
    (503, 506),  # 12 stop_sign
    (415, 418),  # 13 parking_meter
    (244, 247),  # 14 bench
    (25, 31),  # 15 bird
    (77, 86),  # 16 cat
    (112, 129),  # 17 dog
    (130, 146),  # 18 horse
    (175, 186),  # 19 sheep
    (97, 107),  # 20 cow
    (314, 325),  # 21 elephant
    (236, 239),  # 22 bear
    (596, 600),  # 23 zebra
    (343, 348),  # 24 giraffe
    (209, 214),  # 25 backpack
    (577, 584),  # 26 umbrella
    (353, 356),  # 27 handbag
    (539, 546),  # 28 tie
    (507, 516),  # 29 suitcase
    (337, 342),  # 30 Frisbee
    (464, 474),  # 31 skis
    (475, 483),  # 32 snowboard
    (489, 502),  # 33 sports_ball
    (369, 376),  # 34 kite
    (225, 232),  # 35 baseball_bat
    (233, 235),  # 36 baseball_glove
    (454, 463),  # 37 skateboard
    (517, 528),  # 38 surfboard
    (534, 538),  # 39 tennis_racket
    (47, 54),  # 40 bottle
    (589, 595),  # 41 wine_glass
    (296, 305),  # 42 cup
    (331, 336),  # 43 fork
    (377, 383),  # 44 knife
    (484, 488),  # 45 spoon
    (253, 257),  # 46 bowl
    (215, 224),  # 47 banana
    (199, 208),  # 48 apple
    (439, 445),  # 49 sandwich
    (398, 407),  # 50 orange
    (258, 264),  # 51 broccoli
    (274, 283),  # 52 carrot
    (357, 363),  # 53 hot_dog
    (419, 429),  # 54 pizza
    (306, 313),  # 55 donut
    (265, 273),  # 56 cake
    (87, 92),  # 57 chair
    (93, 96),  # 58 couch
    (171, 174),  # 59 potted_plant
    (240, 243),  # 60 bed
    (108, 111),  # 61 dining_table
    (551, 558),  # 62 toilet
    (195, 198),  # 63 TV
    (384, 389),  # 64 laptop
    (394, 397),  # 65 mouse
    (435, 438),  # 66 remote
    (364, 368),  # 67 keyboard
    (284, 290),  # 68 cell_phone
    (390, 393),  # 69 microwave
    (408, 414),  # 70 oven
    (547, 550),  # 71 toaster
    (450, 453),  # 72 sink
    (430, 434),  # 73 refrigerator
    (248, 252),  # 74 book
    (291, 295),  # 75 clock
    (585, 588),  # 76 vase
    (446, 449),  # 77 scissors
    (529, 533),  # 78 teddy_bear
    (349, 352),  # 79 hair_drier
    (559, 562),  # 80 toothbrush
)


def Generate_HICO_detection(output_file, HICO_dir):
    """Convert a prediction JSON file into one ``.mat`` file per object category.

    Args:
        output_file: path of the JSON prediction file handed to ``Format_Pred``.
        HICO_dir: output directory.  It is created when missing, and every
            non-directory entry already inside it is deleted before the new
            results are written.
    """
    if not os.path.exists(HICO_dir):
        os.makedirs(HICO_dir)

    # Remove previous results
    for entry in os.listdir(HICO_dir):
        entry_path = os.path.join(HICO_dir, entry)
        # os.remove cannot delete directories; leave them alone instead of crashing.
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            continue
        os.remove(entry_path)

    # save_HICO looks HOI classes up via str(id) while Format_Pred keys its
    # result by int, so hand over string keys; otherwise no class is found.
    HICO = dict((str(hoi_cls), dets) for hoi_cls, dets in Format_Pred(output_file).items())

    # Make sure the directory ends with a separator so that the file name can
    # be appended to it by plain concatenation as well as by os.path.join.
    out_dir = os.path.join(HICO_dir, '')

    for classid, (begin, finish) in enumerate(_HICO_OBJECT_RANGES, 1):
        save_HICO(HICO, out_dir, classid, begin, finish)


if __name__ == '__main__':
    # Script entry point: convert the default prediction file into the
    # default results directory.
    Generate_HICO_detection(
        output_file='best_predictions.json',
        HICO_dir='./ppdm_results/',
    )
